
import numpy as np
import pandas as pd

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

from config import DATA_DIR, LABEL_VALUES


def cut(df, per_class=666, labels=("positive", "neutral", "negative"), random_state=None):
    """Return a class-balanced random subset of ``df``.

    For every label in ``labels`` the rows whose ``sentiment`` column equals
    that label are shuffled and at most ``per_class`` of them are kept. The
    per-label pieces are concatenated in the order given by ``labels``; the
    original index values are preserved.

    Args:
        df: DataFrame with a ``sentiment`` column.
        per_class: Maximum number of rows kept per label (default 666, the
            value that used to be hard-coded).
        labels: Sentiment values to keep, in output order. Rows with any
            other sentiment are dropped.
        random_state: Forwarded to ``DataFrame.sample``; pass an int for a
            reproducible shuffle. ``None`` keeps the previous behaviour.

    Returns:
        A new DataFrame with at most ``per_class * len(labels)`` rows.
    """
    pieces = []
    for label in labels:
        subset = df[df["sentiment"] == label]
        # Shuffle first, then truncate: a class with fewer than `per_class`
        # rows is returned whole instead of raising.
        shuffled = subset.sample(frac=1, random_state=random_state)
        pieces.append(shuffled.head(per_class))
    return pd.concat(pieces, axis=0)
        


# eval_data_true_labels = list(eval_data["sentiment"])
# eval_data_true_labels_num = [LABEL_VALUES.index(lab) for lab in eval_data_true_labels]



def obtain_probs():
    """Run the sentiment model over the evaluation rows and save the results.

    Reads ``{DATA_DIR}/data_copy.csv``, keeps the first 6630 rows and the
    columns ``segment``, ``sentiment`` and ``id``, predicts every segment one
    at a time and writes a tab-separated ``eval_data.csv`` with the columns
    ``segment``, ``true_label``, ``id``, ``sentiment`` (predicted label) and
    ``probabilities`` (raw model output for the segment).

    The previous version read ``eval_data`` before ever assigning it (the
    only assignments were commented out), so it always failed with
    ``UnboundLocalError``. The evaluation frame is now bound explicitly.
    """
    from predict import SentimentPredictor
    # Not referenced in this function; kept as in the original.
    # NOTE(review): drop if nothing relies on this module being imported.
    from bert_preprocess import BertPreprocessor  # noqa: F401

    data_copy = pd.read_csv(f"{DATA_DIR}/data_copy.csv")
    data_copy = data_copy.head(6630)[["segment", "sentiment", "id"]]

    # Evaluate on every loaded row. The segment-length filter and the
    # class balancing via cut() that once sat here are disabled.
    eval_data = data_copy.assign(id=data_copy["id"].astype(int))
    eval_sents = list(eval_data["segment"])

    sent_pred = SentimentPredictor()
    sent_pred.load_model()

    # Keep the gold label under its own name so the predicted label can
    # take over the "sentiment" column.
    eval_data = eval_data.rename({"sentiment": "true_label"}, axis=1)
    print(eval_data)

    preds = []
    prop_preds = []
    props_pred_all = []
    for sentence in eval_sents:
        # Presumably a 1-D array of per-class scores (it is used with
        # max() and .argmax()) — confirm against SentimentPredictor.predict.
        sentence_probs = sent_pred.predict([sentence], pretty=False)
        props_pred_all.append(sentence_probs)
        top_prob = max(sentence_probs)
        prop_preds.append([top_prob, 1 - top_prob])
        preds.append(LABEL_VALUES[sentence_probs.argmax()])

    eval_data["sentiment"] = preds
    # to_csv stores each entry as its string form; plot_roc_curve parses
    # that "[a b c]" text back into floats.
    eval_data["probabilities"] = props_pred_all

    # NOTE(review): the return value was never used in the original either.
    # The call is kept in case pretty=True has visible side effects; remove
    # it if it does not, since it runs the model over all sentences again.
    sent_pred.predict(eval_sents, pretty=True)
    print(prop_preds)
    eval_data.to_csv("eval_data.csv", index=False, sep="\t")



def plot_roc_curve():
    """Plot one ROC curve per label from the saved evaluation results.

    Loads the tab-separated ``eval_data.csv``, turns the textual
    ``probabilities`` column back into lists of floats, and draws a
    one-vs-rest ROC curve for each of the first three ``LABEL_VALUES``
    together with a red reference diagonal. The figure is shown with
    ``plt.show()``; nothing is returned.
    """

    def _parse_probs(raw):
        # Strip the first and last character (the brackets of "[a b c]")
        # and split on single spaces, skipping the empty tokens produced
        # by runs of spaces.
        tokens = [tok for tok in raw[1:-1].split(" ") if tok]
        return [float(tok) for tok in tokens]

    frame = pd.read_csv("eval_data.csv", sep="\t")

    print(frame["probabilities"][0])
    frame["probabilities"] = frame["probabilities"].apply(_parse_probs)
    print(frame)

    true_labels = list(frame["true_label"])
    true_label_ids = [LABEL_VALUES.index(name) for name in true_labels]
    print(list(frame["probabilities"]))

    # One curve per class; the gold labels are also passed in the `preds`
    # slot of process_for, exactly as before.
    _, ax = plt.subplots()
    for class_ix in range(3):
        fpr, tpr, _thresholds, _roc_auc = process_for(
            class_ix, true_label_ids, true_labels, frame
        )
        ax.plot(fpr, tpr, label="ROC: " + LABEL_VALUES[class_ix])

    # Reference diagonal, drawn in axes coordinates.
    diagonal = [0.05, 0.95]
    ax.plot(
        diagonal,
        diagonal,
        transform=ax.transAxes,
        label="Random classifier",
        color="red",
    )
    ax.legend(loc=4)
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title("ROC curves")
    plt.show()


def process_for(i, eval_data_true_labels_num, preds, eval_data):
    """Compute the one-vs-rest ROC curve and AUC for label index ``i``.

    Args:
        i: Index of the label treated as the positive class.
        eval_data_true_labels_num: Sequence of gold label indices.
        preds: Unused. Kept so existing callers keep working; the previous
            version converted it to indices and then discarded the result.
        eval_data: DataFrame whose ``probabilities`` column holds, per row,
            an indexable sequence of scores; element ``i`` is used as the
            score for the positive class.

    Returns:
        Tuple ``(fpr, tpr, threshold, roc_auc)`` as produced by
        ``sklearn.metrics.roc_curve`` and ``sklearn.metrics.auc``.
    """
    # Binarise the gold labels: +1 for the class of interest, -1 otherwise.
    y_true = np.where(np.asarray(eval_data_true_labels_num) == i, 1, -1)
    print(y_true)

    # Score of class `i` for every row.
    y_pred = np.array([row[i] for row in eval_data["probabilities"]])
    print(y_pred)

    fpr, tpr, threshold = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)

    return fpr, tpr, threshold, roc_auc


# Script entry point: only the plotting step runs. plot_roc_curve() reads
# "eval_data.csv", which obtain_probs() writes; obtain_probs() is not
# called here, so that file must already exist.
if __name__ == "__main__":
    plot_roc_curve()
